package comet

import (
	acsclient "acs/comet/client"
	"acs/comet/config"
	"acs/comet/hook"
	"acs/comet/proto"
	"acs/pbmodel"
	"acs/util"
	"encoding/binary"
	"errors"
	"fmt"
	"net"
	"os"
	"os/signal"
	"runtime/debug"
	"sync"
	"sync/atomic"
	"syscall"
	"time"

	log "github.com/cihub/seelog"
	pb "github.com/golang/protobuf/proto"
	"go.uber.org/zap"

	"github.com/sunreaver/grace/gracenet"
	logData "github.com/sunreaver/logger"
)

// Second is one second expressed as an int64 nanosecond count (int64(time.Second)).
const Second = int64(time.Second)

// Server state shared by the goroutines in this file. Apart from the initial globalHook,
// these are assigned by StartAppConfigServer.
var (
	// exitWG is incremented for the accept loop and for every connection handler.
	exitWG *sync.WaitGroup
	// clientList registers every connected client.
	clientList *acsclient.ClientList
	// loggerData receives structured per-packet debug records.
	loggerData *zap.SugaredLogger
	// Conf is the active configuration; StartAppConfigServer sets it to config.Conf.
	Conf *config.Config
	// globalHook is installed on every new client. It starts as hook.EmptyRegisterHook
	// and is rebuilt from Conf.Hook by StartAppConfigServer.
	globalHook = hook.EmptyRegisterHook
)

// logger is the seelog logger used throughout this file. SetLogger replaces it while
// holding setMutex.
var (
	logger   = log.Default
	setMutex = new(sync.Mutex)
)

// Errors reported for malformed or oversized packets.
var (
	ErrParsePacketLength               = errors.New("Packet length parse error")
	ErrParsePacketTransactionIdAndType = errors.New("Packet transaction ID and type parse error")
	// ErrPacketLenExcess is built by StartAppConfigServer, once Conf.MaxProtoSize is known.
	ErrPacketLenExcess   error
	ErrPacketInvalidType = errors.New("Invalid packet type.")
)

// appConnectionCounter is the number of currently open app connections; it is only
// accessed through sync/atomic.
var appConnectionCounter int32

// PacketResult is one item passed from readPacketFromClient to HandleTCPConn over the
// per-connection result channel: either a decoded packet or an error.
type PacketResult struct {
	// Tsid is the transaction ID of the packet, or of the frame an error relates to.
	Tsid    uint16
	// Packet is the decoded packet; readPacketFromClient leaves it nil on error results.
	Packet  *proto.Packet
	// err, when non-nil, is passed with Tsid and errCode to client.HandleProtoError by
	// HandleTCPConn.
	err     error
	// errCode is the numeric error code accompanying err.
	errCode int32
}

// netIoTimeout returns an absolute I/O deadline: now plus Conf.MaxTansDataIoTimeout
// milliseconds. Callers in this file dereference the result straight into SetReadDeadline.
func netIoTimeout() *time.Time {
	timeout := time.Duration(Conf.MaxTansDataIoTimeout) * time.Millisecond
	deadline := time.Now().Add(timeout)
	return &deadline
}

// SetLogger replaces the seelog logger used by this package (seelog's log.Default until
// set). Concurrent calls to SetLogger are serialised by setMutex; the code in this file
// reads logger without taking that mutex.
func SetLogger(newLogger log.LoggerInterface) {
	setMutex.Lock()
	logger = newLogger
	setMutex.Unlock()
}

// StartAppConfigServer initialises the package state from config.Conf and starts serving
// app connections on addr in a background goroutine.
//
// The address is resolved and the listener opened before this function returns, so a bad
// address or a failed bind is logged and returned as an error. Previously it panicked
// inside the background goroutine, which crashed the process and left the error return
// value always nil. The accept loop is registered on exitWaitGroup and runs until
// exitService is closed or accepting fails.
func StartAppConfigServer(addr string, clients *acsclient.ClientList, exitWaitGroup *sync.WaitGroup) error {
	loggerData = logData.GetSugarLogger("app_config_server.log")

	Conf = config.Conf
	ErrPacketLenExcess = fmt.Errorf("packet length exceeded %v", Conf.MaxProtoSize)
	clientList = clients
	globalHook = hook.NewHook(Conf.Hook)
	exitWG = exitWaitGroup

	tcpAddr, err := net.ResolveTCPAddr("tcp", addr)
	if err != nil {
		logger.Error(err)
		return err
	}

	// Listen through gracenet (gnet); waitRestart uses the same gnet to start a
	// replacement process on SIGHUP.
	logger.Infof("starting tcp at %v", addr)
	tl, err := gnet.ListenTCP("tcp4", tcpAddr)
	if err != nil {
		logger.Error(err)
		return err
	}

	exitWG.Add(1)
	go serveAppConnections(tl)
	return nil
}

// serveAppConnections runs the accept loop on tl. Each accepted connection is handed to
// HandleTCPConn unless Conf.MaxAppConnections connections are already open, in which case
// the peer is sent a rejection and disconnected. It returns when exitService is closed or
// acceptTCP reports an accept error. On return it closes tl, marks exitWG done and flushes
// the logger.
func serveAppConnections(tl *net.TCPListener) {
	defer func() {
		exitWG.Done()
		logger.Flush()
	}()
	defer tl.Close()

	connChan := make(chan *net.TCPConn)
	go acceptTCP(tl, connChan)

	// Wait for the restart signal in the background.
	go waitRestart(gnet)
	for {
		select {
		case <-exitService:
			return
		case conn := <-connChan:
			if conn == nil {
				// acceptTCP reported an accept error: stop serving.
				return
			}
			tc := atomic.LoadInt32(&appConnectionCounter)
			if tc >= Conf.MaxAppConnections {
				logger.Warnf("maxium client connections [%v] recached[%v]...reject %v.", Conf.MaxAppConnections, tc, conn.RemoteAddr())
				rejectMoreConnections(conn)
				conn.Close()
				continue
			}
			go HandleTCPConn(conn)
		}
	}
}

// waitRestart blocks until SIGHUP arrives and then asks gracenet to start a replacement
// process. If that fails, the error is logged and it waits for the next SIGHUP. On success
// it stops receiving SIGHUP and closes exitService, which the accept loop and the
// connection handlers select on to shut down.
func waitRestart(gnet *gracenet.Net) {
	hup := make(chan os.Signal, 1)
	signal.Notify(hup, syscall.SIGHUP)

	for range hup {
		pid, err := gnet.StartProcess()
		if err != nil {
			logger.Errorf("[grace]StartProcess error: %s.", err.Error())
			continue
		}
		logger.Infof("[grace]StartProcess New [pid]: %d.", pid)
		signal.Stop(hup)

		// Shut this process's service down.
		close(exitService)
		return
	}
}

// rejectMoreConnections tells the peer it was refused because the connection limit was
// reached. It sends a "conn_reject" request (transaction 0) carrying a pbmodel.ProtoErr
// with code 1001, with a one-second deadline on the connection. Failures are only logged;
// the caller closes the connection.
func rejectMoreConnections(conn net.Conn) {
	reason := &pbmodel.ProtoErr{
		Code: pb.Int(1001),
		Msg:  pb.String(fmt.Sprintf("Maxium connections reached[%v].", Conf.MaxAppConnections)),
	}
	payload, err := pb.Marshal(reason)
	if err != nil {
		logger.Warnf("error message marshal error: %v", err)
		return
	}

	deadline := time.Now().Add(time.Second)
	conn.SetDeadline(deadline)
	frame := proto.NewRequest(0, "conn_reject", payload, Conf.DataEncryptKey, Conf.DataEncryptIV).Encode()
	if _, err = conn.Write(frame); err != nil {
		logger.Warnf("reject app connection error: %v", err)
	}
}

// acceptTCP accepts connections on tl and hands each one to the serving loop through ch.
//
// On an accept error it logs the error (unless exitService is already closed), sends a
// single nil so the consumer stops, and returns. The consumer closes tl after receiving
// nil, so looping again would only fail again and block forever on a send nobody
// receives. Every send also watches exitService, so this goroutine cannot block after the
// consumer has returned. A connection accepted during shutdown is closed instead of
// leaked.
func acceptTCP(tl *net.TCPListener, ch chan *net.TCPConn) {
	for {
		conn, err := tl.AcceptTCP()
		if err != nil {
			select {
			case <-exitService:
				return
			default:
				logger.Warn(err)
			}
			select {
			case ch <- nil:
			case <-exitService:
			}
			return
		}

		select {
		case ch <- conn:
		case <-exitService:
			conn.Close()
			return
		}
	}
}

// HandleTCPConn serves one app connection. It registers the client in clientList, starts
// readPacketFromClient as the reader goroutine and dispatches what the reader delivers. It
// stops when any of these happens:
//   - the connection breaks;
//   - a read or protocol error is reported;
//   - the peer is silent for Conf.HeartBeatTimeout milliseconds;
//   - the service exits.
//
// Every connection- or protocol-level error must end the session, and the connection is
// then closed.
func HandleTCPConn(conn *net.TCPConn) {
	atomic.AddInt32(&appConnectionCounter, 1)
	exitWG.Add(1)
	var client *acsclient.Client
	defer func() {
		atomic.AddInt32(&appConnectionCounter, -1)
		err := recover()
		if err != nil {
			util.PrintPanicStack()
			logger.Errorf("client handler routine exited unexpectedly - RemoteAddr: %v, LocalAddr: %v [ %v ]: %s", conn.RemoteAddr(), conn.LocalAddr(), err, debug.Stack())
		}
		exitWG.Done()
	}()
	remoteAddr := conn.RemoteAddr().String()
	logger.Infof("Accepted connection from %v", remoteAddr)
	err := conn.SetKeepAlive(true)
	if err != nil {
		logger.Warn(err)
		conn.Close()
		return
	}
	client = acsclient.NewClient(&acsclient.ClientConfig{
		Conn:              conn,
		HandshakeTimeout:  Conf.ClientHandshakeTimeout,
		ExitWaitGroup:     exitWG,
		MaxReqRespTimeout: Conf.MaxReqRespTimeout,
		DataEncryptKey:    Conf.DataEncryptKey,
		DataEncryptIV:     Conf.DataEncryptIV,
	})
	client.Hooks = globalHook

	ele, _ := clientList.AddClient(remoteAddr, client)
	packetResultChan := make(chan PacketResult, 1)
	emptyResult := PacketResult{}
	readerProcessingLock := new(sync.Mutex)
	readerExitChan := make(chan bool)
	go readPacketFromClient(client, packetResultChan, readerProcessingLock, readerExitChan)

	// The deferred handler above recovers panics. Without this defer, a panic below would
	// skip the teardown at the end of this function and leak the reader goroutine, the
	// clientList entry and the connection. This defer performs whichever teardown steps
	// have not run yet.
	readerStopped, tornDown := false, false
	defer func() {
		if !readerStopped {
			close(readerExitChan)
		}
		if !tornDown {
			clientList.RemoveClient(remoteAddr, ele)
			client.Close()
		}
	}()

	// resetTimer restarts t for d. It first discards any expiry nobody consumed, so a stale
	// tick cannot end the next wait immediately. Before Go 1.23, Timer.Reset does not drain
	// t.C.
	resetTimer := func(t *time.Timer, d time.Duration) {
		if !t.Stop() {
			select {
			case <-t.C:
			default:
			}
		}
		t.Reset(d)
	}

	// Idle timeout for the peer.
	tIdle := time.Millisecond * time.Duration(Conf.HeartBeatTimeout)
	clientTimeout := time.NewTimer(tIdle)
	defer clientTimeout.Stop()
LOOPHDL:
	for {
		var packetResult PacketResult
		select {
		// The peer did not respond, or stayed idle, for too long.
		case <-clientTimeout.C:
			logger.Warnf("Client idled or response timedout for %v, closing...", tIdle)
			break LOOPHDL

		case packetResult = <-packetResultChan:
			resetTimer(clientTimeout, tIdle)

		// Normal service shutdown.
		case <-exitService:
			break LOOPHDL

		case err = <-client.ConnErr:
			logger.Warnf("client connection broken: %v", err)
			// Guards against business processing reporting errors unrelated to the connection.
			break LOOPHDL
		}

		// After a read or parse error the protocol gives no way to resynchronise, so the
		// connection cannot be used any further; stop serving it.
		if packetResult.err != nil {
			client.HandleProtoError(packetResult.Tsid, packetResult.errCode, packetResult.err)
			logger.Warn(packetResult.err)
			break LOOPHDL
		}
		// A zero result means packetResultChan was closed: the reader has exited.
		if packetResult == emptyResult {
			break LOOPHDL
		}
		processPacketResult(client, &packetResult)
	}

	close(readerExitChan)
	readerStopped = true
	// Wait until the reader has finished any packet it is in the middle of, then handle
	// whatever it still delivers.
	readerProcessingLock.Lock()
	defer readerProcessingLock.Unlock()
	tWaitTimeout := time.Millisecond * 20
	readerExitTimeout := time.NewTimer(tWaitTimeout)
	defer readerExitTimeout.Stop()
WaitReader:
	for {
		select {
		case <-readerExitTimeout.C:
			break WaitReader
		case packetResult := <-packetResultChan:
			resetTimer(readerExitTimeout, tWaitTimeout)
			if packetResult == emptyResult {
				break WaitReader
			}
			if packetResult.err != nil {
				client.HandleProtoError(packetResult.Tsid, packetResult.errCode, packetResult.err)
				logger.Warn(packetResult.err)
			} else {
				processPacketResult(client, &packetResult)
			}
		}
	}

	tornDown = true
	clientList.RemoveClient(remoteAddr, ele)
	client.Close()
}

// readPacketFromClient is the reader goroutine of one connection, started by
// HandleTCPConn. It decodes frames from client.Conn and delivers every decoded packet, or
// a protocol error, on packetResultChan. It closes packetResultChan when it exits.
//
// Frame layout as parsed below (integers big-endian):
//
//	length   proto.HeadLength bytes
//	version  proto.VerLength bytes (read and logged, otherwise unused)
//	tsid<<2 | packet type, in proto.TsidTypeLen bytes
//	body     `length` bytes. A request body is cmd + '\n' + encrypted payload;
//	         for the other packet types the whole body is the encrypted payload.
//
// readerExitChan is checked between frames; once it is closed the reader stops.
// readerProcessingLock is held from the moment a frame's length field has been read until
// the frame is fully consumed. HandleTCPConn takes the lock after closing readerExitChan,
// so it waits for an in-flight packet instead of cutting it off.
//
// Error codes produced here:
//   - 10: payload decryption failed
//   - 12: short tsid/type field
//   - 13: length above Conf.MaxProtoSize
//   - 16: short version field
//   - 400: recovered panic (sent when the reader exits)
//   - 401: connection-level read failure (sent when the reader exits)
//
// ErrPacketInvalidType is sent with errCode 0.
func readPacketFromClient(client *acsclient.Client, packetResultChan chan PacketResult, readerProcessingLock *sync.Mutex, readerExitChan chan bool) {
	var (
		// err is a connection-level failure; it is reported once, as errCode 401, on exit.
		err error
		// transID is the transaction ID from the most recently parsed frame.
		transID uint16
		// statusLocked records whether this goroutine currently holds readerProcessingLock.
		statusLocked bool
	)
	releaseLock := func() {
		if statusLocked {
			statusLocked = false
			readerProcessingLock.Unlock()
		}
	}
	// sendResult delivers r to HandleTCPConn, releasing the processing lock first. After
	// leaving its main loop, HandleTCPConn waits for that lock before it receives again.
	// Sending while holding the lock could therefore deadlock both goroutines once the
	// channel buffer is full.
	sendResult := func(r PacketResult) {
		releaseLock()
		packetResultChan <- r
	}
	defer func() {
		if errI := recover(); errI != nil {
			sendResult(PacketResult{
				err:     fmt.Errorf("System error: %v", errI),
				errCode: 400,
				Tsid:    transID,
			})
		} else if err != nil {
			sendResult(PacketResult{
				err:     fmt.Errorf("client data read error: %v", err),
				errCode: 401,
				Tsid:    transID,
			})
		}
		releaseLock()
		close(packetResultChan)
	}()

LOOPClientR:
	for {
		if err != nil {
			logger.Warnf("client connection broken: %v", err)
			break
		}

		select {
		case <-readerExitChan:
			break LOOPClientR
		default:
		}

		packetResult := PacketResult{}
		var count int

		// Length field. An I/O timeout before any byte of it has arrived only means the
		// connection is idle. Any other failure, or a timeout part-way through the field,
		// ends the reader.
		lenBuf := make([]byte, proto.HeadLength)
		for offset := 0; offset < proto.HeadLength; {
			client.Conn.SetReadDeadline(*netIoTimeout())
			count, err = client.Conn.Read(lenBuf[offset:])
			if err == nil {
				offset += count
				continue
			}
			if nErr, ok := err.(net.Error); !ok || !nErr.Timeout() {
				break LOOPClientR
			}
			switch {
			case count == 0 && offset == 0:
				// No new packet within the timeout: go back and re-check readerExitChan.
				err = nil
				continue LOOPClientR
			case count > 0:
				// Part of the field arrived just as the deadline expired; keep reading.
				offset += count
			default:
				// Timed out with the field half read; the connection is no longer usable.
				break LOOPClientR
			}
		}

		// A frame has started: hold the lock until it has been consumed completely.
		readerProcessingLock.Lock()
		statusLocked = true
		packetLen := binary.BigEndian.Uint32(lenBuf)

		// TODO: the protocol version is not used yet.
		verB := make([]byte, proto.VerLength)
		client.Conn.SetReadDeadline(*netIoTimeout())
		count, err = client.Conn.Read(verB)
		if err != nil {
			err = fmt.Errorf("client connection read error where read protocol ver： %v", err)
			break
		}
		if count != proto.VerLength {
			// err is nil here. Report an explicit error: a result with a nil err and a nil
			// Packet would be treated as a packet and dereferenced by the consumer.
			packetResult.err = fmt.Errorf("protocol version length is [%v], expected [%v]", count, proto.VerLength)
			packetResult.errCode = 16
			sendResult(packetResult)
			break
		}
		logger.Debugf("Got client protocol version %v", uint8(verB[0]))

		// Transaction ID (upper 14 bits) and packet type (lower 2 bits).
		tsbuf := make([]byte, proto.TsidTypeLen)
		client.Conn.SetReadDeadline(*netIoTimeout())
		if count, err = client.Conn.Read(tsbuf); err != nil || count < proto.TsidTypeLen {
			if err != nil {
				logger.Warnf("client connection read error where read transaction type and no.: %v", err)
			} else {
				logger.Warnf("TransactionTypeAndId data length is [%v] less than [%v].", count, proto.TsidTypeLen)
				packetResult.err = ErrParsePacketTransactionIdAndType
				packetResult.errCode = 12
				sendResult(packetResult)
			}
			break
		}
		tsidAndType := binary.BigEndian.Uint16(tsbuf)
		packetType := proto.PacketType(tsidAndType & 3)
		transID = tsidAndType >> 2
		packetResult.Tsid = transID

		if packetLen > Conf.MaxProtoSize {
			packetResult.err = ErrPacketLenExcess
			packetResult.errCode = 13
			logger.Warnf("Request packet length [%v] exceeded max [%v].", packetLen, Conf.MaxProtoSize)
			sendResult(packetResult)
			break
		}

		packet := &proto.Packet{Type: packetType}
		switch packetType {
		case proto.PacketTypeRequest:
			cmdBuf, cmdErr := readRequestCmd(client.Conn, packetLen)
			if cmdErr != nil {
				// The frame was only partly consumed, so the stream cannot be resynchronised.
				err = cmdErr
				logger.Warnf("client connection read error where read cmd content: %v", err)
				break LOOPClientR
			}
			// readRequestCmd guarantees len(cmdBuf)+1 <= packetLen, so this cannot underflow.
			dataBuf, readErr, decErr := readEncryptedBody(client.Conn, packetLen-uint32(len(cmdBuf))-1)
			if readErr != nil {
				err = readErr
				logger.Warnf("client connection read error where read request data: %v", err)
				break LOOPClientR
			}
			if decErr != nil {
				packetResult.err = fmt.Errorf("Failed to decrypt request data for transaction[%v]: %v", transID, decErr)
				packetResult.errCode = 10
				sendResult(packetResult)
				break LOOPClientR
			}
			packet.Req = proto.NewRequest(transID, string(cmdBuf), dataBuf, Conf.DataEncryptKey, Conf.DataEncryptIV)

		case proto.PacketTypeResponse:
			dataBuf, readErr, decErr := readEncryptedBody(client.Conn, packetLen)
			if readErr != nil {
				err = readErr
				break LOOPClientR
			}
			if decErr != nil {
				packetResult.err = fmt.Errorf("Failed to decrypt response data for transaction[%v]: %v", transID, decErr)
				packetResult.errCode = 10
				sendResult(packetResult)
				break LOOPClientR
			}
			packet.Resp = proto.NewResponse(transID, dataBuf, Conf.DataEncryptKey, Conf.DataEncryptIV)

		case proto.PacketTypeProtoError:
			dataBuf, readErr, decErr := readEncryptedBody(client.Conn, packetLen)
			if readErr != nil {
				err = readErr
				break LOOPClientR
			}
			if decErr != nil {
				packetResult.err = fmt.Errorf("Failed to decrypt response error data for transaction[%v]: %v", transID, decErr)
				packetResult.errCode = 10
				sendResult(packetResult)
				break LOOPClientR
			}
			packet.ProtoErr = proto.NewProtoError(transID, dataBuf, Conf.DataEncryptKey, Conf.DataEncryptIV)

		default:
			packetResult.err = ErrPacketInvalidType
			sendResult(packetResult)
			break LOOPClientR
		}

		packetResult.Packet = packet
		sendResult(packetResult)
	}
}

// readRequestCmd reads the '\n'-terminated command name at the start of a request body,
// one byte at a time. The command and its terminator must fit within packetLen. This lets
// the caller compute the payload size without underflow, and stops a peer from making the
// reader buffer an unbounded command. The terminator is not included in the result.
func readRequestCmd(conn net.Conn, packetLen uint32) ([]byte, error) {
	var cmd []byte
	b := make([]byte, 1)
	// consumed counts the body bytes (command plus terminator) read so far.
	for consumed := uint32(0); consumed < packetLen; {
		conn.SetReadDeadline(*netIoTimeout())
		n, err := conn.Read(b)
		if err != nil {
			return nil, err
		}
		if n == 0 {
			continue
		}
		consumed++
		if b[0] == '\n' {
			return cmd, nil
		}
		cmd = append(cmd, b[0])
	}
	return nil, fmt.Errorf("request cmd is not terminated within packet length %v", packetLen)
}

// readEncryptedBody reads size bytes of packet body from conn and decrypts them with
// Conf.DataEncryptKey/Conf.DataEncryptIV. readErr reports a failed read; decryptErr
// reports a failed decryption of a fully read body.
func readEncryptedBody(conn net.Conn, size uint32) (data []byte, readErr, decryptErr error) {
	data = make([]byte, size)
	if readErr = readConnDataWithTimeout(conn, data); readErr != nil {
		return nil, readErr, nil
	}
	data, decryptErr = proto.Decrypt(data, Conf.DataEncryptKey, Conf.DataEncryptIV)
	return data, nil, decryptErr
}

// processPacketResult logs a decoded packet to loggerData (when set) and dispatches it by
// packet type to the client's handler, each in a new goroutine. A result without a packet
// is logged and dropped instead of panicking on a nil dereference.
func processPacketResult(client *acsclient.Client, packetResult *PacketResult) {
	if packetResult == nil || packetResult.Packet == nil {
		logger.Warnf("un-handlable packet result without packet: %+v", packetResult)
		return
	}
	packet := packetResult.Packet

	if loggerData != nil {
		loggerData.Debugw("processPacketResult",
			"tsid", packetResult.Tsid,
			"errCode", packetResult.errCode,
			"err", packetResult.err,
			"packet.type", packet.Type,
			// NOTE(review): only one of Req/Resp/ProtoErr is set per packet, so this relies
			// on ToString accepting a nil receiver — confirm in acs/comet/proto.
			"packet.req", packet.Req.ToString(),
			"packet.resp", packet.Resp.ToString(),
			"packet.err", packet.ProtoErr,
		)
	}

	switch packet.Type {
	case proto.PacketTypeRequest:
		go client.HandleRequest(packet.Req)
	case proto.PacketTypeResponse:
		go client.HandleResponse(packet.Resp)
	case proto.PacketTypeProtoError:
		// Error packet sent by the peer.
		// TODO: let the routine that issued the request learn about this error.
		go client.HandleRemoteProtoError(packet.ProtoErr)
	default:
		logger.Warnf("un-handlable packet type: %v", packet.Type)
	}
}

// readConnDataWithTimeout fills buf completely from conn, which suits larger payloads.
// Every read gets a fresh deadline from netIoTimeout. A read that times out after
// delivering some bytes counts as progress, and reading continues. A timeout with no data,
// or any other error, is returned unchanged, so a peer that sends nothing within one
// timeout period makes the read fail.
func readConnDataWithTimeout(conn net.Conn, buf []byte) error {
	for filled := 0; filled < len(buf); {
		conn.SetReadDeadline(*netIoTimeout())
		n, err := conn.Read(buf[filled:])
		if err == nil {
			filled += n
			continue
		}
		// Only a timeout that still delivered bytes means the peer is alive but slow.
		if ne, isNetErr := err.(net.Error); isNetErr && ne.Timeout() && n > 0 {
			filled += n
			continue
		}
		return err
	}
	return nil
}
